nlp_architect.nn.torch package

Submodules

nlp_architect.nn.torch.distillation module

class nlp_architect.nn.torch.distillation.TeacherStudentDistill(teacher_model: nlp_architect.models.TrainableModel, temperature: float = 1.0, dist_w: float = 0.1, loss_w: float = 1.0, loss_function='kl')[source]

Bases: object

Teacher-Student knowledge distillation helper. Use this object when training a student model with knowledge distillation from a teacher model. A usage sketch follows the parameter list below.

Parameters:
  • teacher_model (TrainableModel) – teacher model
  • temperature (float, optional) – KD temperature. Defaults to 1.0.
  • dist_w (float, optional) – distillation loss weight. Defaults to 0.1.
  • loss_w (float, optional) – student loss weight. Defaults to 1.0.
  • loss_function (str, optional) – loss function to use (kl for KLDivLoss, mse for MSELoss)
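
A minimal construction sketch, assuming my_teacher is a hypothetical, already-trained TrainableModel (the keyword arguments follow the parameter list above):

    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    # my_teacher is a hypothetical, already-trained TrainableModel
    distill = TeacherStudentDistill(
        teacher_model=my_teacher,
        temperature=2.0,      # soften logits before comparing distributions
        dist_w=0.5,           # weight of the distillation term
        loss_w=1.0,           # weight of the regular student loss
        loss_function="kl",   # KLDivLoss between the softened distributions
    )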
static add_args(parser: argparse.ArgumentParser)[source]

Add KD arguments to parser

Parameters: parser (argparse.ArgumentParser) – parser
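
For command-line driven training scripts, the same hyperparameters can be exposed through argparse; a small hedged sketch (the exact flag names registered by add_args are not listed here):

    import argparse

    from nlp_architect.nn.torch.distillation import TeacherStudentDistill

    parser = argparse.ArgumentParser()
    TeacherStudentDistill.add_args(parser)  # registers the KD hyperparameters
    args = parser.parse_args([])            # empty list -> registered defaults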
distill_loss(loss, student_logits, teacher_logits)[source]

Add the KD loss term to the student loss

Parameters:
  • loss – student loss
  • student_logits – student model logits
  • teacher_logits – teacher model logits
Returns:

KD loss

get_teacher_logits(inputs)[source]

Get teacher logits

Parameters: inputs – model inputs
Returns: teacher logits
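
Put together, one knowledge-distillation training step might look like the following sketch; student, criterion, optimizer, batch and batch_labels are hypothetical placeholders, and only the methods documented above are assumed:

    # Hypothetical single training step; distill is the TeacherStudentDistill
    # object constructed earlier.
    student_logits = student(batch)                      # student forward pass
    task_loss = criterion(student_logits, batch_labels)  # regular student loss
    teacher_logits = distill.get_teacher_logits(batch)   # teacher forward pass
    loss = distill.distill_loss(task_loss, student_logits, teacher_logits)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()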

nlp_architect.nn.torch.quantization module

Quantization ops

class nlp_architect.nn.torch.quantization.FakeLinearQuantizationWithSTE[source]

Bases: torch.autograd.function.Function

Simulates the error caused by quantization. Uses the Straight-Through Estimator (STE) for backpropagation.

static backward(ctx, grad_output)[source]

Calculate estimated gradients for fake quantization using Straight-Through Estimator (STE) according to: https://openreview.net/pdf?id=B1ae1lZRb

static forward(ctx, input, scale, bits=8)[source]

Fake-quantize the input according to the scale and number of bits; returns dequantize(quantize(input))
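
Conceptually, the forward pass returns a rounded-and-rescaled copy of the input, while the backward pass passes gradients through as if the operation were the identity. A standalone sketch of the same idea (not the module's exact implementation; the clamping range below is an assumption):

    import torch

    class FakeQuantSketch(torch.autograd.Function):
        """Sketch of fake quantization with a straight-through estimator."""

        @staticmethod
        def forward(ctx, x, scale, bits=8):
            max_q = 2 ** (bits - 1) - 1                              # symmetric integer range
            q = torch.clamp(torch.round(x * scale), -max_q, max_q)   # quantize
            return q / scale                                         # dequantize

        @staticmethod
        def backward(ctx, grad_output):
            # STE: gradient of the identity w.r.t. x; scale and bits get no gradient
            return grad_output, None, None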

class nlp_architect.nn.torch.quantization.QuantizationConfig(activation_bits=8, weight_bits=8, mode='none', start_step=0, ema_decay=0.9999, requantize_output=True)[source]

Bases: nlp_architect.common.config.Config

Quantization Configuration Object
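
A hedged construction example using the keyword arguments shown in the signature above (whether the lowercase mode strings map onto QuantizationMode members, as assumed here, is described next):

    from nlp_architect.nn.torch.quantization import QuantizationConfig

    qcfg = QuantizationConfig(
        activation_bits=8,      # bits used for activation quantization
        weight_bits=8,          # bits used for weight quantization
        mode="ema",             # "none", "dynamic" or "ema" (assumed strings)
        start_step=1000,        # delay quantization until this training step
        ema_decay=0.9999,       # decay of the EMA-tracked activation ranges
        requantize_output=True,
    )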

class nlp_architect.nn.torch.quantization.QuantizationMode[source]

Bases: enum.Enum

Enumeration of supported quantization modes.

NONE = 1
DYNAMIC = 2
EMA = 3
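
The mode string chosen in QuantizationConfig is assumed to select one of these members; a small illustration of that assumed mapping, reusing qcfg from the example above (the attribute access and string-to-member conversion are assumptions):

    from nlp_architect.nn.torch.quantization import QuantizationMode

    mode = QuantizationMode[qcfg.mode.upper()]  # assumed mapping, e.g. "ema" -> EMA
    assert mode is QuantizationMode.EMA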
class nlp_architect.nn.torch.quantization.QuantizedEmbedding(*args, weight_bits=8, start_step=0, mode='none', **kwargs)[source]

Bases: torch.nn.modules.sparse.Embedding

Embedding layer with quantization aware training capability

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

fake_quantized_weight
forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod from_config(*args, config=None, **kwargs)[source]

Initialize quantized layer from config

quantized_forward(input)[source]

Return quantized embeddings

weight_scale
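
A hedged sketch of swapping a standard embedding for its quantization-aware counterpart; the vocabulary and embedding sizes are arbitrary, and the mode string is an assumption:

    import torch
    from nlp_architect.nn.torch.quantization import QuantizedEmbedding

    # Positional arguments mirror torch.nn.Embedding (num_embeddings, embedding_dim)
    emb = QuantizedEmbedding(30000, 768, weight_bits=8, start_step=0, mode="ema")

    token_ids = torch.randint(0, 30000, (2, 16))
    vectors = emb(token_ids)   # quantization-aware forward during training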
class nlp_architect.nn.torch.quantization.QuantizedLinear(*args, activation_bits=8, weight_bits=8, requantize_output=True, ema_decay=0.9999, start_step=0, mode='none', **kwargs)[source]

Bases: torch.nn.modules.linear.Linear

Linear layer with quantization aware training capability

extra_repr()[source]

Set the extra representation of the module

To print customized extra information, you should reimplement this method in your own modules. Both single-line and multi-line strings are acceptable.

fake_quantized_weight
forward(input)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

classmethod from_config(*args, config=None, **kwargs)[source]

Initialize quantized layer from config

get_quantized_bias(scale)[source]
inference_quantized_forward(input)[source]

Simulate quantized inference: quantize the input and perform the calculation with integer numbers only. This function should only be used during inference.

quantized_weight
training_quantized_forward(input)[source]

Fake-quantized forward pass: fake-quantizes weights and activations, and learns quantization ranges when the quantization mode is EMA. This function should only be used during training.

weight_scale
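
An end-to-end sketch for the linear layer, building it from the hypothetical qcfg above and then exercising the two documented forward paths; whether plain forward() dispatches between them automatically based on the module's training flag is an assumption, so the calls are made explicitly:

    import torch
    from nlp_architect.nn.torch.quantization import QuantizedLinear

    layer = QuantizedLinear.from_config(768, 768, config=qcfg)
    x = torch.randn(4, 768)

    layer.train()
    y = layer.training_quantized_forward(x)        # fake-quantized weights/activations

    layer.eval()
    with torch.no_grad():
        y = layer.inference_quantized_forward(x)   # simulated integer-only inference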
nlp_architect.nn.torch.quantization.calc_max_quant_value(bits)[source]

Calculate the maximum symmetric quantized value according to the number of bits

nlp_architect.nn.torch.quantization.dequantize(input, scale)[source]

Linear dequantization of the input according to a given scale

nlp_architect.nn.torch.quantization.get_dynamic_scale(x, bits, with_grad=False)[source]

Calculate a dynamic quantization scale from the input, based on the maximum absolute value of x and the number of bits

nlp_architect.nn.torch.quantization.get_scale(bits, threshold)[source]

Calculate the quantization scale according to a threshold constant and the number of bits

nlp_architect.nn.torch.quantization.quantize(input, scale, bits)[source]

Apply linear quantization to the input according to a scale and number of bits
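
Taken together, these helpers implement symmetric linear quantization: the b-bit integer range is [-(2^(b-1) - 1), 2^(b-1) - 1], the scale maps a real-valued threshold onto that range, quantize rounds into it, and dequantize maps the result back. A hedged numeric round trip (the exact formulas inside the library, e.g. that calc_max_quant_value(8) equals 127, are assumptions consistent with the descriptions above):

    import torch
    from nlp_architect.nn.torch.quantization import (
        calc_max_quant_value, dequantize, get_dynamic_scale, quantize)

    x = torch.tensor([-1.5, -0.3, 0.0, 0.7, 2.0])

    max_q = calc_max_quant_value(8)        # assumed 2 ** (8 - 1) - 1 = 127
    scale = get_dynamic_scale(x, bits=8)   # scale derived from max(|x|) = 2.0
    x_q = quantize(x, scale, bits=8)       # integer-valued tensor in [-127, 127]
    x_hat = dequantize(x_q, scale)         # approximate reconstruction of x

    print((x - x_hat).abs().max())         # error bounded by roughly 0.5 / scale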

Module contents

nlp_architect.nn.torch.set_seed(seed, n_gpus=None)[source]

Set the random seed for reproducibility

nlp_architect.nn.torch.setup_backend(no_cuda)[source]

Set up the backend according to the selected options and the detected hardware configuration
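
A typical hedged setup sequence for a training script; the assumption that setup_backend returns a torch device together with the number of detected GPUs follows common usage in the library's training procedures:

    from nlp_architect.nn.torch import set_seed, setup_backend

    device, n_gpus = setup_backend(no_cuda=False)  # assumed to return (device, n_gpus)
    set_seed(42, n_gpus)                           # seed RNGs, including CUDA when GPUs exist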